# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple, Union

from codegen.cmake_config import *
from codegen.cpp_symbol_map import *

from codegen.ops.fmha_fwd import (
    FmhaFwdTileSize,
    FmhaFwdApiTrait,
    FMHA_FWD_KERNEL_HEADER,
    FMHA_FWD_API_PER_DTYPE,
    FMHA_FWD_API_PER_HDIM_CASE,
)


DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8" : 8,
    "bf8" : 8
}

K0_MAX_SUBMAX_MAP = {
    32 : 32,
    64 : 64,
    96 : 128,
    128: 128,
    # 160: 160,
    256: 256
}

FMHA_FWD_SPLITKV_PIPELINE_MAP = {
    "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
    "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
    "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
}

FMHA_FWD_SPLITKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};

namespace {{
template <bool kHasUnevenSplits, bool kMergeNumHeadGroupsSeqLenQ = false>
struct instance {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;

using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
                                          ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
                                          ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
                                          ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
                                          ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
                                          {F_vlayout}>;

using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
                                                     {F_skpad},
                                                     {F_dpad},
                                                     {F_dvpad},
                                                     {F_logits},
                                                     {F_bias},
                                                     /*kHasBiasGrad=*/false,
                                                     {F_lse},
                                                     {F_squant},
                                                     {F_pagedkv},
                                                     kHasUnevenSplits,
                                                     kMergeNumHeadGroupsSeqLenQ,
                                                     {F_occupancy}>;

using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    fmha_shape,
    {F_mode},
    fmha_variant_{F_idx},
    fmha_mask_{F_idx},
    fmha_trait>;

using fmha_pipeline = {F_pipeline}<
    fmha_pipeline_problem>;

/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
///        store_tile_raw() data corruption issue
using fmha_epilogue =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                           typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                           false, false>>;

using fmha_kernel =
    ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;

static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    using k_ = fmha_kernel;
    auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks             = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}

using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                        {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
                        {F_dvpad}>;

#include <iostream>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-compare"

namespace {{
template <bool kHasUnevenSplits>
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
    if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
                  && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
                      || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
        if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
            instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
        }} else {{
            instance<kHasUnevenSplits>::run(s, a);
        }}
    }} else {{
        instance<kHasUnevenSplits>::run(s, a);
    }}
}}
}} // anonymous namespace

#pragma clang diagnostic pop

template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    if constexpr({F_mode} == false) {{ // batch mode
        // we don't check every seqlen_k values for kvcache
        if (a.seqlen_k_ptr != nullptr) {{
            run_instance</*kHasUnevenSplits=*/true>(s, a);
        // make sure F_bn0 is divisible by F_bk1
        }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
            run_instance</*kHasUnevenSplits=*/false>(s, a);
        }} else {{
            run_instance</*kHasUnevenSplits=*/true>(s, a);
        }}
    }} else {{
        run_instance</*kHasUnevenSplits=*/true>(s, a);
    }}
}}

template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
{{
    using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
    return k_::GetName();
}}
"""

FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

namespace {{
template <ck_tile::index_t kLogMaxSplits>
struct instance {{
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
                                                    {F_dvpad},
                                                    {F_lse},
                                                    {F_squant},
                                                    kLogMaxSplits,
                                                    {F_occupancy}>;

using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    {F_hdim},
    {F_mode},
    {F_bn1},
    fmha_trait>;

using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
    fmha_pipeline_problem>;

/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
///        store_tile_raw() data corruption issue
using fmha_epilogue =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                           typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
                                           false, false>>;

using fmha_kernel =
    ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;

static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    using k_ = fmha_kernel;
    auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks             = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}

using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
                        {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;

#include <iostream>

template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    if (a.num_splits <= 8) {{
        instance<3>::run(s, a);
    }} else if (a.num_splits <= 16) {{
        instance<4>::run(s, a);
    }} else if (a.num_splits <= 32) {{
        instance<5>::run(s, a);
    }} else if (a.num_splits <= 64) {{
        instance<6>::run(s, a);
    }} else if (a.num_splits <= 128) {{
        instance<7>::run(s, a);
    }}
}}

template<>
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
{{
    using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
    return k_::GetName();
}}
"""

FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API="""
#include <iostream>

template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
    if(s.log_level_ > 0)
    std::cout
    << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
    << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
    << std::flush;

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
        [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
    );
}}

float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_FWD_SPLITKV_API_INNER_DISPATCH="""            {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
                        ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
                using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

                // get combine kernel tile sizes
                using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
                constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;

                // make sure we can reuse the padding flags in combine kernels
                static_assert({F_bm0} % kM0 == 0);
                static_assert({F_bn1} % 32 == 0);

                if (t.has_lse) {{
                    if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
                        return -1;
                    }} else {{
                        using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;

                        return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
                    }}
                }} else {{
                    using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;

                    return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
                }}
            }}
"""

@dataclass
class FmhaFwdSplitKVApiTrait:
    pipeline_tag : str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim      : str
    dtype     : str  # data type
    mode      : str  # value from MODE_MAP
    bm0       : int  # tile size along q seqlen (block size)
    bn0       : int  # tile size along qk seqlen
    bk0       : int  # tile size along qk gemm unroll
    bn1       : int  # tile size along v head_dim
    bk1       : int  # tile size along kv gemm unroll
    bk0max    : int
    vlayout   : str
    mask      : str
    logits    : str
    bias      : str  #
    lse       : str  #
    squant    : str  #
    spad      : str
    skpad     : str
    dpad      : str
    dvpad     : str
    pagedkv   : str

    @property
    def name(self) -> str:
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
                    f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
                    f'{self.dvpad}-{self.pagedkv}'

    @property
    def scheck(self) -> str:
        if self.mode == 'group': return 'true/*group mode spad always true*/'                  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.spad == 't' : return 'true' # always support
            else :                return 'true'
        elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.seqlen_q % {self.bm0} == 0'
        else: assert False

    @property
    def skcheck(self) -> str:
        if self.mode == 'group': return 'true/*group mode skpad always true*/'                  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.seqlen_k % {self.bn0} == 0'
        else: assert False

    @property
    def dcheck(self) -> str:
        if self.pipeline_tag == 'qr_async':
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
            else :               assert False
        elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :               return f'a.hdim_q % {bk0submax} == 0'
        else:   assert False

    @property
    def dvcheck(self) -> str:
        if self.pipeline_tag == 'qr_async':
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
            else :                assert False
        elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.hdim_v % {bk0submax} == 0'
        else:   assert False

@dataclass
class FmhaFwdSplitKVPipeline:
    tag : str

    F_vlayout   : str  # row/col
    F_spad      : str  # true/false
    F_skpad     : str  #
    F_dpad      : str  #
    F_dvpad     : str  #
    F_logits    : str  # t/f
    F_bias      : str  # true/false
    F_lse       : str  #
    F_squant    : str  #
    F_pagedkv   : str  # t/f
    F_mask      : str  # value from MASK_MAP

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't' : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'{self.tag}_v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        else: n += '_npad'

        if self.F_logits == 't' : n += '_logits'
        else: n += '_nlogits'

        if self.F_bias != 'no' : n += f'_{self.F_bias}'
        else: n += '_nbias'

        if self.F_mask[0:2] == 's_':
            if self.F_mask == 's_mask': n += f'_mask'
            else: n += '_nmask'
        else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
            else: n += '_nmask'

        if self.F_lse == 't' : n += '_lse'
        else: n += '_nlse'

        if self.F_squant == 't' : n += '_squant'
        else: n += '_nsquant'

        if self.F_pagedkv == 't' : n += '_pagedkv'
        else: n += '_npagedkv'
        return n

@dataclass
class FmhaFwdSplitKVCombinePipeline:
    tag : str

    F_spad      : str  # true/false
    F_dvpad     : str  #
    F_lse       : str  #
    F_squant    : str  #

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'{self.tag}'
        if pn != '' : n += f'_{pn}'
        else: n += '_npad'

        if self.F_lse == 't' : n += '_lse'
        else: n += '_nlse'

        if self.F_squant == 't' : n += '_squant'
        else: n += '_nsquant'
        return n

class FmhaFwdSplitKVApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes=str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case=str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits=self.pool[dtype][hdim]
                inners=str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                   F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                   F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                                   F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
                                   F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                   F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                   F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
                                   F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        if not per_dtypes:
            # empty string we add some ignore to suppress warning in api
            per_dtypes += '    (void)t ; (void)s ; (void)a;'
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)

@dataclass
class FmhaFwdSplitKVCombineTileSize:
    F_bn1       : int  # tile size along v head_dim
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    @property
    def name(self) -> str:
        return f"b{self.F_bn1}" +\
            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
class FmhaFwdSplitKVKernel:
    F_idx           : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim          : int  # hdim
    F_dtype         : str  # data type
    F_mode          : str  # value from MODE_MAP
    F_tile          : FmhaFwdTileSize
    F_pipeline      : FmhaFwdSplitKVPipeline
    mask_impl       : str

    @property
    def template(self) -> str:
        kernel_body = str()
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_SPLITKV_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = FWD_DTYPE_MAP[self.F_dtype],
                F_bm0           = self.F_tile.F_bm0,
                F_bn0           = self.F_tile.F_bn0,
                F_bk0           = self.F_tile.F_bk0,
                F_bn1           = self.F_tile.F_bn1,
                F_bk1           = self.F_tile.F_bk1,
                F_bk0max        = self.F_tile.F_bk0max,
                F_rm0           = self.F_tile.F_rm0,
                F_rn0           = self.F_tile.F_rn0,
                F_rk0           = self.F_tile.F_rk0,
                F_rm1           = self.F_tile.F_rm1,
                F_rn1           = self.F_tile.F_rn1,
                F_rk1           = self.F_tile.F_rk1,
                F_wm0           = self.F_tile.F_wm0,
                F_wn0           = self.F_tile.F_wn0,
                F_wk0           = self.F_tile.F_wk0,
                F_wm1           = self.F_tile.F_wm1,
                F_wn1           = self.F_tile.F_wn1,
                F_wk1           = self.F_tile.F_wk1,
                F_vlayout       = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad          = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad         = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad          = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_logits        = BOOL_MAP[self.F_pipeline.F_logits],
                F_bias          = BIAS_MAP[self.F_pipeline.F_bias],
                F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
                F_pagedkv       = BOOL_MAP[self.F_pipeline.F_pagedkv],
                F_occupancy     = self.F_tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode          = MODE_MAP[self.F_mode],
                F_pipeline      = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag])

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
                self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdSplitKVApiTrait:
        return FmhaFwdSplitKVApiTrait(
                pipeline_tag=self.F_pipeline.tag,
                hdim=str(self.F_hdim),
                dtype=self.F_dtype,
                mode=self.F_mode,
                bm0=self.F_tile.F_bm0,
                bn0=self.F_tile.F_bn0,
                bk0=self.F_tile.F_bk0,
                bn1=self.F_tile.F_bn1,
                bk1=self.F_tile.F_bk1,
                bk0max=self.F_tile.F_bk0max,
                vlayout=self.F_pipeline.F_vlayout,
                logits=self.F_pipeline.F_logits,
                mask=self.F_pipeline.F_mask,
                bias=self.F_pipeline.F_bias,
                lse=self.F_pipeline.F_lse,
                squant=self.F_pipeline.F_squant,
                pagedkv=self.F_pipeline.F_pagedkv,
                spad=self.F_pipeline.F_spad,
                skpad=self.F_pipeline.F_skpad,
                dpad=self.F_pipeline.F_dpad,
                dvpad=self.F_pipeline.F_dvpad)

@dataclass
class FmhaFwdSplitKVCombineKernel:
    F_idx           : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim          : int  # hdim
    F_dtype         : str  # data type
    F_mode          : str  # value from MODE_MAP
    F_tile          : FmhaFwdSplitKVCombineTileSize
    F_pipeline      : FmhaFwdSplitKVCombinePipeline

    @property
    def template(self) -> str:
        kernel_body = str()
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = FWD_DTYPE_MAP[self.F_dtype],
                F_bn1           = self.F_tile.F_bn1,
                F_spad          = BOOL_MAP[self.F_pipeline.F_spad],
                F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
                F_occupancy     = self.F_tile.F_occupancy,
                F_mode          = MODE_MAP[self.F_mode])

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
                self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : FmhaFwdTileSize(32, 64,  16, 32,  32,  32,   2, 1, 1,  2, 1, 1,  16, 16, 16,  16, 16, 16,  -1),
            '64'  : FmhaFwdTileSize(64, 64,  32, 64,  32,  64,   4, 1, 1,  4, 1, 1,  16, 16, 16,  16, 16, 16,  -1),
        ### '96'  : FmhaFwdTileSize(64, 128, 32, 128, 32,  96,   4, 1, 1,  4, 1, 1,  16, 16, 16,  16, 16, 16,  -1),
            '128' : FmhaFwdTileSize(64, 128, 32, 128, 32,  128,  4, 1, 1,  4, 1, 1,  16, 16, 16,  16, 16, 16,  -1),
        ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32,  160,  4, 1, 1,  4, 1, 1,  16, 16, 16,  16, 16, 16,  -1),
            '256' : FmhaFwdTileSize(64, 128, 32, 256, 32,  256,  4, 1, 1,  4, 1, 1,  16, 16, 16,  16, 16, 16,  -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdTileSize(128, 64,  32, 64,  32,  64,   2, 1, 1,  2, 1, 1,  32, 32, 32,  32, 32, 32,  -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32,  128,  4, 1, 1,  4, 1, 1,  32, 32, 32,  32, 32, 32,  -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32,  256,  4, 1, 1,  4, 1, 1,  32, 32, 32,  32, 32, 32,  -1),
        }
    else:
        return None

def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : FmhaFwdSplitKVCombineTileSize(32,  -1),
            '64'  : FmhaFwdSplitKVCombineTileSize(32,  -1),
        ### '96'  : FmhaFwdSplitKVCombineTileSize(32,  -1),
            '128' : FmhaFwdSplitKVCombineTileSize(32,  -1),
        ### '160' : FmhaFwdSplitKVCombineTileSize(32,  -1),
            '256' : FmhaFwdSplitKVCombineTileSize(32,  -1),
    }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdSplitKVCombineTileSize(32,  -1),
            '128' : FmhaFwdSplitKVCombineTileSize(32,  -1),
            '256' : FmhaFwdSplitKVCombineTileSize(32,  -1),
        }
    else:
        return None

def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
    Pipeline = FmhaFwdSplitKVPipeline
    Kernel = FmhaFwdSplitKVKernel

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
                # TODO: use async pipeline when compiler is more stable
                if hdim == 256 or hdim in [32, 64, 128]:         ### [32, 64, 96, 128, 160]:
                # if True:
                    pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))

                    pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))

                    pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))

                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
                else:
                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask))
                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
                    if receipt == 1:
                        pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
                        pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim
        elif dtype in ['fp8', 'bf8']:
            for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask))
        elif dtype in ['fp8fp16', 'fp8bf16']:
            # TODO
            None
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdSplitKVApiPool(mask_impl)

    for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
        if d == None:
            continue
        #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                        continue
                # logits_soft_cap is only allowed if no bias
                if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
                    continue
                k = Kernel(F_idx=0,
                           F_hdim=hdim,
                           F_dtype=dtype,
                           F_mode=mode,
                           F_tile=tile,
                           F_pipeline=pipeline,
                           mask_impl=mask_impl)
                if kernel_filter != '':
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                # Flash attention integration
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'alibi']
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ['fp16, bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'bias']
                    cond &= pipeline.F_squant == 'f'
                    cond &= mode == 'batch'
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= mode == "group"
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                # aiter::mha_fwd_splikv C++ api integration
                elif receipt == 600:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)

def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]:
    Pipeline = FmhaFwdSplitKVCombinePipeline
    Kernel = FmhaFwdSplitKVCombineKernel

    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]):
                pipelines.append(Pipeline('unused', spad, dvpad, lse, squant))
        elif dtype in ['fp8', 'bf8']:
            # no need lse kernels
            pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant))
        else:
            assert False
        return pipelines

    gen = list()

    for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
        if d == None:
            continue
        #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    if pipeline.F_spad != 't':
                        # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                        continue
                k = Kernel(F_idx=0,
                           F_hdim=hdim,
                           F_dtype=dtype,
                           F_mode=mode,
                           F_tile=tile,
                           F_pipeline=pipeline)
                if kernel_filter != '':
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                # Aiter(mha_varlen_fwd) integration
                if receipt == 200:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= mode == "group"
                    if not cond:
                        continue
                # aiter::mha_fwd_splikv C++ api integration
                elif receipt == 600:
                    cond = dtype in ['fp16', 'bf16']
                    if not cond:
                        continue
                gen.append(k)

    return gen

def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None:
    file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
    file_path.write_text(api_pool.api)

def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
    filter_list = filter_list.split('@')
    filter_list.extend([''] * (2 - len(filter_list)))
    assert optdim_list == [-1]

    kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
    api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
    write_fwd_splitkv_api(api_pool, output_dir)

def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
    filter_list = filter_list.split('@')
    filter_list.extend([''] * (2 - len(filter_list)))
    assert optdim_list == [-1]

    with file_path.open('a') as f:
        kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
